from __future__ import annotations

import abc
from abc import abstractmethod
from typing import TYPE_CHECKING, Any, Callable, Protocol, runtime_checkable

from litestar.enums import ScopeType
from litestar.middleware._utils import (
    build_exclude_path_pattern,
    should_bypass_middleware,
)
from litestar.utils.deprecation import warn_deprecation

__all__ = (
    "ASGIMiddleware",
    "AbstractMiddleware",
    "DefineMiddleware",
    "MiddlewareProtocol",
)

if TYPE_CHECKING:
    from litestar.middleware.constraints import MiddlewareConstraints
    from litestar.types import RouteHandlerType, Scopes
    from litestar.types.asgi_types import ASGIApp, Receive, Scope, Send


@runtime_checkable
class MiddlewareProtocol(Protocol):
    """Abstract middleware protocol."""

    __slots__ = ("app",)

    app: ASGIApp

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Execute the ASGI middleware.

        Called by the previous middleware in the stack if a response is not awaited prior.

        Upon completion, middleware should call the next ASGI handler and await it - or await a response created in its
        closure.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None
        """


class DefineMiddleware:
    """Container enabling passing ``*args`` and ``**kwargs`` to Middleware class constructors and factory functions."""

    __slots__ = ("args", "kwargs", "middleware")

    def __init__(self, middleware: Callable[..., ASGIApp], *args: Any, **kwargs: Any) -> None:
        """Initialize ``DefineMiddleware``.

        Args:
            middleware: A callable that returns an ASGIApp.
            *args: Positional arguments to pass to the callable.
            **kwargs: Key word arguments to pass to the callable.

        Notes:
            The callable will be passed a kwarg ``app``, which is the next ASGI app to call in the middleware stack.
            It therefore must define such a kwarg.
        """
        self.middleware = middleware
        self.args = args
        self.kwargs = kwargs

    def __call__(self, app: ASGIApp) -> ASGIApp:
        """Call the middleware constructor or factory.

        Args:
            app: An ASGIApp, this value is the next ASGI handler to call in the middleware stack.

        Returns:
            Calls :class:`DefineMiddleware.middleware <.DefineMiddleware>` and returns the ASGIApp created.
        """

        return self.middleware(*self.args, app=app, **self.kwargs)


class AbstractMiddleware:
    """Abstract middleware providing base functionality common to all middlewares, for dynamically engaging/bypassing
    the middleware based on paths, ``opt``-keys and scope types.

    When implementing new middleware, this class should be used as a base.
    """

    scopes: Scopes = {ScopeType.HTTP, ScopeType.WEBSOCKET}
    exclude: str | list[str] | None = None
    exclude_opt_key: str | None = None

    def __init__(
        self,
        app: ASGIApp,
        exclude: str | list[str] | None = None,
        exclude_opt_key: str | None = None,
        scopes: Scopes | None = None,
    ) -> None:
        """Initialize the middleware.

        Args:
            app: The ``next`` ASGI app to call.
            exclude: A pattern or list of patterns to match against a request's path.
                If a match is found, the middleware will be skipped.
            exclude_opt_key: An identifier that is set in the route handler
                ``opt`` key which allows skipping the middleware.
            scopes: ASGI scope types, should be a set including
                either or both 'ScopeType.HTTP' and 'ScopeType.WEBSOCKET'.
        """
        self.app = app
        self.scopes = scopes or self.scopes
        self.exclude_opt_key = exclude_opt_key or self.exclude_opt_key
        self.exclude_pattern = build_exclude_path_pattern(exclude=(exclude or self.exclude), middleware_cls=type(self))

    @classmethod
    def __init_subclass__(cls, **kwargs: Any) -> None:
        if not any(c.__module__.startswith("litestar") and c is not AbstractMiddleware for c in cls.mro()):
            # we don't want to warn about usage of 'AbstractMiddleware' if users aren't
            # directly subclassing it, i.e. they're subclassing another Litestar
            # middleware which itself subclasses 'AbstractMiddleware'
            warn_deprecation(
                version="2.15",
                deprecated_name="AbstractMiddleware",
                kind="class",
                alternative="litestar.middleware.ASGIMiddleware",
            )

        super().__init_subclass__(**kwargs)

        original__call__ = cls.__call__

        async def wrapped_call(self: AbstractMiddleware, scope: Scope, receive: Receive, send: Send) -> None:
            if should_bypass_middleware(
                scope=scope,
                scopes=self.scopes,
                exclude_path_pattern=self.exclude_pattern,
                exclude_opt_key=self.exclude_opt_key,
            ):
                await self.app(scope, receive, send)
            else:
                await original__call__(self, scope, receive, send)  # pyright: ignore

        # https://github.com/python/mypy/issues/2427#issuecomment-384229898
        setattr(cls, "__call__", wrapped_call)

    @abstractmethod
    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Execute the ASGI middleware.

        Called by the previous middleware in the stack if a response is not awaited prior.

        Upon completion, middleware should call the next ASGI handler and await it - or await a response created in its
        closure.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function.

        Returns:
            None
        """
        raise NotImplementedError("abstract method must be implemented")


class ASGIMiddleware(abc.ABC):
    """An abstract base class to easily construct ASGI middlewares, providing
    functionality to dynamically skip the middleware based on ASGI ``scope["type"]``,
    handler ``opt`` keys or path patterns and a simple way to pass configuration to
    middlewares.

    This base class does not implement an ``__init__`` method, so subclasses are free
    to use it to customize the middleware's configuration.

    .. important::
        An instance of the individual middlewares will be created *once* and used to
        build up the internal middleware stack. As such, middlewares should *not* be
        stateful, as this state will be shared across all requests.
        Any connection-specific state should be scoped to the `handle` implementation.
        Not doing so would typically lead to conflicting variable reads / writes across
        requests, and - most likely - bugs.


    .. code-block:: python

        class MyMiddleware(ASGIMiddleware):
            scopes = (ScopeType.HTTP,)
            exclude_path_pattern = ("/not/this/path",)
            exclude_opt_key = "exclude_my_middleware"

            def __init__(self, my_logger: Logger) -> None:
                self.logger = my_logger

            async def handle(
                self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp
            ) -> None:
                self.logger.debug("Received request for path %s", scope["path"])
                await next_app(scope, receive, send)
                self.logger.debug("Processed request for path %s", scope["path"])


        app = Litestar(..., middleware=[MyMiddleware(logger=my_logger)])

    .. versionadded:: 2.15
    """

    scopes: tuple[ScopeType, ...] = (
        ScopeType.HTTP,
        ScopeType.WEBSOCKET,
        ScopeType.ASGI,
    )
    """Scope types this middleware should be applied to"""
    exclude_path_pattern: str | tuple[str, ...] | None = None
    r"""
    A regex pattern (or tuple of patterns) to exclude this middleware from route
    handlers whose path matches any of the provided patterns.

    .. important::
        Pattern matching is performed against the **handler's path** (e.g.,
        ``/user/{user_id:int}/``), NOT against the actual **request path** (e.g.,
        ``/user/1234/``). This is a critical distinction for dynamic routes.

        If you need to exclude based on paths dynamically, use
        :attr:`~litestar.middleware.ASGIMiddleware.should_bypass_for_scope`
        instead, matching on ``scope["path"]``.

    **Example 1: Static path**

    Handler path::

        /api/health

    To exclude this handler, use a pattern like::

        exclude_path_pattern = r"^/api/health$"

    **Example 2: Dynamic path (path parameters)**

    Handler path::

        /user/{user_id:int}/profile
             └─────┬──────┘
                   └─ This is what the pattern matches against

    Actual request paths that match this handler::

        /user/1234/profile
        /user/5678/profile
        /user/9999/profile

    To exclude this handler, the pattern must match the **handler**, not the actual request path::

        exclude_path_pattern = "/user/{user_id:int}/profile"
        exclude_path_pattern = "/user/\{.+?\}/"
    """
    exclude_opt_key: str | None = None
    """
    Exclude this middleware for handlers with an opt-key of this name that is truthy
    """
    should_bypass_for_scope: Callable[[Scope], bool] | None = None
    r"""
    A callable that takes in the :class:`~litestar.types.Scope` of the current
    connection and returns a boolean, indicating if the middleware should be skipped for
    the current request.

    This can for example be used to exclude a middleware based on a dynamic path::

        should_bypass_for_scope = lambda scope: scope["path"].endswith(".jpg")

    Applied to a route with a dynamic path like ``/static/{file_name:str}``, it would
    be skipped *only* if ``file_name`` has a ``.jpg`` extension.

    .. note::

        If it is not required to dynamically match the path of a request,
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_path_pattern` should be
        used instead. Since its exclusion is done statically at startup time, it has no
        performance cost at runtime.

    .. versionadded:: 3.0
    """
    constraints: MiddlewareConstraints | None = None

    def should_bypass_for_handler(self, handler: RouteHandlerType) -> bool:
        """Return ``True`` if this middleware should be bypassed for ``handler``, according
        to :attr:`~litestar.middleware.ASGIMiddleware.scopes`,
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_path_pattern` or
        :attr:`~litestar.middleware.ASGIMiddleware.exclude_opt_key`, otherwise ``False``.
        """
        from litestar.handlers import ASGIRouteHandler, HTTPRouteHandler, WebsocketRouteHandler

        if isinstance(handler, HTTPRouteHandler) and ScopeType.HTTP not in self.scopes:
            return True
        if isinstance(handler, WebsocketRouteHandler) and ScopeType.WEBSOCKET not in self.scopes:
            return True
        if isinstance(handler, ASGIRouteHandler) and ScopeType.ASGI not in self.scopes:
            return True

        if self.exclude_opt_key and handler.opt.get(self.exclude_opt_key):
            return True

        pattern = build_exclude_path_pattern(exclude=self.exclude_path_pattern, middleware_cls=type(self))
        if pattern and any(pattern.search(path) for path in handler.paths):
            return True

        return False

    def __call__(self, app: ASGIApp) -> ASGIApp:
        """Create the actual middleware callable"""
        handle = self.handle

        should_bypass_for_scope = self.should_bypass_for_scope
        if should_bypass_for_scope is None:

            async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
                await handle(scope=scope, receive=receive, send=send, next_app=app)
        else:

            async def middleware(scope: Scope, receive: Receive, send: Send) -> None:
                if should_bypass_for_scope(scope):
                    await app(scope, receive, send)
                else:
                    await handle(scope=scope, receive=receive, send=send, next_app=app)

        return middleware

    @abc.abstractmethod
    async def handle(self, scope: Scope, receive: Receive, send: Send, next_app: ASGIApp) -> None:
        """Handle ASGI call.

        Args:
            scope: The ASGI connection scope.
            receive: The ASGI receive function.
            send: The ASGI send function
            next_app: The next ASGI application in the middleware stack to call
        """

        raise NotImplementedError
